Wrap ForwardContext around full model forward #789
Conversation
Hello @calpt,
Hi, [...] However, if I don't check whether the new [...] Maybe use a [...]
Makes sense, done that.
Hello,

```python
import transformers
import adapters
import torch

model = transformers.T5ForConditionalGeneration(
    transformers.T5Config(
        num_layers=2,
        num_decoder_layers=2,
    )
)
adapters.init(model)
model.add_adapter("a")

adapters.ForwardContext.context_args.add("task_ids")
model.generate(input_ids=torch.randint(0, 1000, (3, 128)), task_ids=torch.randint(0, 3, (3,)))
```

which returns
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[20], line 16
     13 model.add_adapter("a", overwrite_ok=True)
     15 adapters.ForwardContext.context_args.add("task_ids")
---> 16 model.generate(input_ids=torch.randint(0, 1000, (3,128)), task_ids=torch.randint(0, 3, (3,)))

File ~/Documents/labo/etr-peft-composition/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/Documents/labo/etr-peft-composition/src/adapters/hf_transformers/src/transformers/generation/utils.py:2009, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2006 assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
   2008 generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
-> 2009 self._validate_model_kwargs(model_kwargs.copy())
   2010 self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
   2012 # 2. Set generation parameters if not already defined

File ~/Documents/labo/etr-peft-composition/src/adapters/hf_transformers/src/transformers/generation/utils.py:1388, in GenerationMixin._validate_model_kwargs(self, model_kwargs)
   1385     unused_model_args.append(key)
   1387 if unused_model_args:
-> 1388     raise ValueError(
   1389         f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
   1390         " generate arguments will also show up in this list)"
   1391     )

ValueError: The following `model_kwargs` are not used by the model: ['task_ids'] (note: typos in the generate arguments will also show up in this list)
```

Additional Notes
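The error above comes from `generate()`'s kwarg validation: any kwarg that does not match a parameter of the model's forward signature is reported as unused before the forward pass (and any ForwardContext) ever runs. A minimal stand-in for that check, not the actual transformers implementation, looks like this:

```python
import inspect

def validate_model_kwargs(forward_fn, model_kwargs, allowed_extra=()):
    # Sketch of the check seen in the traceback above: kwargs that neither
    # match a forward() parameter nor an explicitly allowed extra are rejected.
    params = set(inspect.signature(forward_fn).parameters)
    unused = [k for k in model_kwargs if k not in params and k not in allowed_extra]
    if unused:
        raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")

def forward(input_ids=None, attention_mask=None):
    pass

# 'task_ids' is rejected unless the validation is taught about context args:
try:
    validate_model_kwargs(forward, {"input_ids": 1, "task_ids": 2})
except ValueError as e:
    print(e)  # → The following `model_kwargs` are not used by the model: ['task_ids']

# Passes once the context arg is whitelisted:
validate_model_kwargs(forward, {"input_ids": 1, "task_ids": 2}, allowed_extra={"task_ids"})
```

This is why simply registering `task_ids` as a context arg is not enough: the validation has to know about registered context args as well.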
Force-pushed de309e4 to b38c4a9
Looks good!
This PR adapts the ForwardContext to be applied to the full model (including head) forward pass. The original base model forward wrapper is now moved to `wrap_base` to make sure no second ForwardContext is created for a single forward pass. This enables passing custom args that are defined in the ForwardContext definition to the top-level model call, as discussed in #783 (see the `generate()` example earlier in this thread). In that example, the forward context automatically adds the passed context args as attributes, i.e. they can be accessed from within the forward pass.
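The mechanism can be sketched self-containedly. The class below is a simplified stand-in for the library's ForwardContext: names like `context_args`, `wrap`, and `get_context` follow the library's API, but this is an illustrative reimplementation under the behavior described in this PR, not the actual adapters code:

```python
import functools
import threading

class ForwardContext:
    # Args registered here are routed into the context instead of forward():
    context_args = {"task_ids"}
    _local = threading.local()

    def __init__(self, **kwargs):
        # Registered context args become attributes of the context instance.
        for name in list(kwargs):
            if name in self.context_args:
                setattr(self, name, kwargs.pop(name))

    def __enter__(self):
        ForwardContext._local.current = self
        return self

    def __exit__(self, *exc):
        ForwardContext._local.current = None

    @classmethod
    def get_context(cls):
        return getattr(cls._local, "current", None)

    @classmethod
    def wrap(cls, forward_fn):
        # Wrap the FULL model forward; re-entrant calls reuse the existing
        # context, so only one ForwardContext is created per forward pass.
        @functools.wraps(forward_fn)
        def wrapped(self, *args, **kwargs):
            if cls.get_context() is not None:
                return forward_fn(self, *args, **kwargs)
            ctx_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in cls.context_args}
            with cls(**ctx_kwargs):
                return forward_fn(self, *args, **kwargs)
        return wrapped

class ToyModel:
    @ForwardContext.wrap
    def forward(self, hidden):
        # Inside the pass, the custom arg is read off the current context:
        ctx = ForwardContext.get_context()
        return hidden, getattr(ctx, "task_ids", None)

print(ToyModel().forward("hidden", task_ids=[0, 1, 2]))  # → ('hidden', [0, 1, 2])
```

The key design point mirrored here is the re-entrancy guard in `wrap`: with the context now opened at the top-level (head) forward, the inner base model forward must not open a second one.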